# -*- coding: utf-8 -*-
"""
Fit the FACs and FAMACs absorption coefficient with Elliott Model
For Jongchul's TPC paper
Using the approach of Chris Davies in DOI: 10.1038/s41467-017-02670-2
Modified by multiplying the prefactor by E**(1/2)
Using LMFIT

@update: 9/5/2020
@author: wenger
"""

import numpy as np
import matplotlib.pyplot as plt
import pymc3 as pm
from scipy import signal
from scipy.stats import lognorm
from lmfit import Model


" Functions "

#  BROADENING FUNCTIONS
        
# Broadening Functions
def gauss(x, sig_g, mu):
	"""Evaluate a unit-height Gaussian-type profile.

	Computes ``exp(-((x - mu) / sig_g)**2)``. The profile is not
	area-normalised: its value at ``x == mu`` is 1. There is no factor of
	1/2 in the exponent, so ``sig_g`` equals sqrt(2) times the conventional
	standard deviation.

	Parameters
	----------
	x : scalar or array
		Abscissa value(s).
	sig_g : scalar
		Width parameter dividing the offset from ``mu``.
	mu : scalar
		Centre of the profile.

	Returns
	-------
	Profile value(s), same shape as ``x``.
	"""
	reduced_offset = (x - mu) / sig_g
	return np.exp(-(reduced_offset ** 2))



def lorentzian(x, sig_L, mu):
	"""Evaluate a Lorentzian-shaped profile centred at ``mu``.

	Computes ``(sig_L**2 / pi) / ((x - mu)**2 + sig_L**2)``. With this
	prefactor the value at ``x == mu`` is ``1/pi`` regardless of ``sig_L``,
	and the area under the curve is ``sig_L`` (i.e. this is ``sig_L`` times
	the unit-area Lorentzian), so it is not normalised to unit area.

	Parameters
	----------
	x : scalar or array
		Abscissa value(s).
	sig_L : scalar
		Half width at half maximum of the profile.
	mu : scalar
		Centre of the profile.

	Returns
	-------
	Profile value(s), same shape as ``x``.
	"""
	width_sq = sig_L ** 2
	numerator = width_sq / np.pi
	denominator = (x - mu) ** 2 + width_sq
	return numerator / denominator



def asymm_peak(x,sig_L, sig_g, mu):
	"""Evaluate an asymmetric peak stitched from two different line shapes.

	For ``x > mu`` the value is ``lorentzian(x, sig_L, mu)``; for
	``x <= mu`` it is ``gauss(x, sig_g, mu)``. Both shapes are evaluated
	over all of ``x`` and ``np.where`` only selects between them.

	Parameters
	----------
	x : scalar or array
		Abscissa value(s).
	sig_L : scalar
		Width passed to the Lorentzian branch (used where ``x > mu``).
	sig_g : scalar
		Width passed to the Gaussian branch (used where ``x <= mu``).
	mu : scalar
		Common centre and switching point of the two branches.

	Returns
	-------
	numpy.ndarray with the selected branch value at each ``x``.
	"""
	upper_branch = lorentzian(x, sig_L, mu)
	lower_branch = gauss(x, sig_g, mu)
	is_above_centre = x > mu
	return np.where(is_above_centre, upper_branch, lower_branch)



def logistic(x, A, sig,k, mu):
	"""Evaluate a logistic (sigmoid) step of height ``A`` centred at ``mu``.

	Computes ``A / (1 + exp(-sig * k * (x - mu)))``. The value at
	``x == mu`` is ``A / 2``; the product ``sig * k`` sets how steep the
	step is.

	Parameters
	----------
	x : scalar or array
		Abscissa value(s).
	A : scalar
		Plateau height reached far above ``mu`` (for positive ``sig * k``).
	sig, k : scalar
		Factors whose product scales the offset from ``mu`` in the exponent.
	mu : scalar
		Midpoint of the step.

	Returns
	-------
	Step value(s), same shape as ``x``.
	"""
	exponent = -sig * k * (x - mu)
	denominator = 1 + np.exp(exponent)
	return A / denominator


# Excitonic Part
def exciton(x, param, mu_all):
	"""Return the summed excitonic absorption lines evaluated at ``x``.

	The result is a series of broadened lines at ``Eg - E_ex / m**2`` for
	``m = 1 .. 4``, each weighted by ``4 * pi * E_ex**(3/2) / m**3``:

	* ``m = 1``: a single Gaussian, or, when ``two_dim`` is truthy, a
	  doublet split by ``split`` around ``Eg - E_ex`` whose lower component
	  is a Gaussian and whose upper component is an ``asymm_peak`` with
	  Lorentzian width ``asymm_L * sig_exc``.
	* ``m = 2 .. 4``: Gaussians of width ``sig_exc``, additionally scaled
	  by ``u``.

	Parameters
	----------
	x : scalar or array
		Abscissa value(s) at which the lines are evaluated.
	param : sequence of 10 values
		``(A, A2, k, bw_cont_fine, u, sig_exc, asymm_L, sig_L, two_dim,
		delta)``. Only ``u``, ``sig_exc``, ``asymm_L`` and ``two_dim`` are
		read here. The ``sig_L`` entry is ignored: the Lorentzian width is
		always derived as ``asymm_L * sig_exc``.
	mu_all : sequence of 4 values
		``(E_ex, split, Eg, Ew)``. ``Ew`` is not used here.

	Returns
	-------
	numpy.ndarray (or NumPy scalar) with the shape of ``x``.

	Raises
	------
	ValueError
		If ``param`` or ``mu_all`` does not have the expected length.
	"""
	# Full-length positional unpacking is kept so that a malformed tuple
	# still fails loudly; `_` marks entries this function does not read.
	_, _, _, _, u, sig_exc, asymm_L, _, two_dim, _ = param
	E_ex, split, Eg, _ = mu_all

	# Common weight of every line before the 1/m**3 factor; loop-invariant.
	prefactor = 4 * np.pi * E_ex ** (3 / 2)
	mu1 = Eg - E_ex

	if two_dim:
		mu1a = mu1 - 0.5 * split
		mu1b = mu1 + 0.5 * split
		lorentz_width = asymm_L * sig_exc
		alpha_mu1 = prefactor * (
			gauss(x, sig_exc, mu1a)
			+ asymm_peak(x, lorentz_width, sig_exc, mu1b)
		)
	else:
		alpha_mu1 = prefactor * gauss(x, sig_exc, mu1)

	# np.shape (rather than len) lets scalar and N-d inputs work too.
	alpha_mu_rest = np.zeros(np.shape(x))
	for m in range(2, 5):
		mu = Eg - E_ex / m ** 2
		alpha_mu_rest += u * (prefactor / m ** 3) * gauss(x, sig_exc, mu)

	return alpha_mu1 + alpha_mu_rest



# Continuum Part
def continuum(x, params, mu_all):
	"""Return the continuum absorption step evaluated at ``x``.

	The continuum is modelled as a logistic step of height ``A2`` centred
	at ``Eg`` whose steepness is set by the product ``sig_exc * k``.

	Parameters
	----------
	x : scalar or array
		Abscissa value(s) at which the step is evaluated.
	params : sequence of 10 values
		``(A, A2, k, bw_cont_fine, u, sig_exc, asymm_L, sig_L, two_dim,
		delta)``. Only ``A2``, ``k`` and ``sig_exc`` are read here.
	mu_all : sequence of 4 values
		``(E_ex, split, Eg, Ew)``. Only ``Eg`` is read here.

	Returns
	-------
	Step value(s), same shape as ``x``.

	Raises
	------
	ValueError
		If ``params`` or ``mu_all`` does not have the expected length.
	"""
	# Read from the `params` argument. The previous body unpacked the
	# module-level `param` instead, silently ignoring what the caller passed.
	# `_` marks entries this function does not use.
	_, A2, k, _, _, sig_exc, _, _, _, _ = params
	_, _, Eg, _ = mu_all

	# NOTE: an additional sum of three Gaussians (width sig_exc) placed on a
	# ladder starting at Ew and stepped using delta used to be computed here,
	# but it was excluded from the returned value. That unused computation
	# has been dropped; the result is the logistic step alone, as before.
	return logistic(x, A2, sig_exc, k, Eg)




# Demo driver: build the model curve (exciton lines + continuum step) on a
# fixed grid and show it on a logarithmic y axis.

# 500 evenly spaced abscissa points from 1.5 to 3.5
# (presumably photon energy in eV -- TODO confirm).
x = np.linspace(1.5,3.5,500)

# Parameters
# E_ex: exciton() places lines at Eg - E_ex/m**2 and weights them with E_ex**(3/2).
E_ex = 0.2
# split: separation of the two m = 1 components used when two_dim is set.
split = 0.01
# u: extra weight applied to the m = 2..4 lines in exciton().
u = 1

# Eg: limit of the exciton line series and midpoint of the continuum step.
Eg = 2.8
Ew = 2.9		# = B (original author's label); does not affect the curve plotted below

# bw_cont_fine: packed into `param` but does not affect the curve plotted below.
bw_cont_fine = 0.05
# NOTE(review): the original comment on this line read "dielectric correction
# factor"; it is not attached to any assignment -- confirm which parameter
# it was meant to describe.
# sig_exc: Gaussian width of the exciton lines; also multiplies k in the logistic step.
sig_exc = 0.1
# asymm_L: with two_dim set, exciton() uses asymm_L * sig_exc as the Lorentzian width.
asymm_L = 0.1
# sig_L: packed into `param`, but exciton() replaces it by asymm_L * sig_exc
# when two_dim is set and does not use it otherwise.
sig_L = 0.02
# k: steepness factor of the logistic step (the exponent is scaled by sig_exc * k).
k = 300
# A: packed into `param` but does not affect the curve plotted below.
A = 1
A2 = 2			# = h (original author's label); height of the logistic continuum step
# two_dim: when True, exciton() draws the m = 1 line as a split, asymmetric doublet.
two_dim = True
delta = 0.1		# = beep (original author's label); does not affect the curve plotted below

# Positional parameter tuple; exciton() and continuum() rely on this exact order.
param = (A, A2, k, bw_cont_fine, u, sig_exc, asymm_L, sig_L, two_dim, delta)

# Positional tuple of the energy-like parameters, in the order the model functions unpack them.
mu_all = (E_ex, split, Eg, Ew)


# Total model curve: excitonic lines plus continuum step.
z = exciton(x, param, mu_all) + continuum(x, param, mu_all)




plt.plot(x, z)
plt.yscale("log")
plt.show()


"""

# Cosh function
def fb_cosh(x, sig, mu):
	return 1 / np.cosh((x-mu)/sig)

# Lognorm broadening function (same as Davies)
def fbroad(x, sig, mu, sig0):
    x0 = np.linspace(-2, 2, x.size)
    # https://stackoverflow.com/questions/47136595/scipy-convolve-depends-on-x?rq=1
    def g(x, x0, sig, mu, sig0):
        return signal.convolve(lognorm.pdf(x0, sig0, loc=mu), gauss(x, sig, 0, norm=0), mode='same')
    g_x = g(x, x0, sig, mu, sig0)
    argE = x[np.argmax(g_x)]
    a = g(x+argE, x0, sig, mu, sig0)/np.trapz(g_x, x)
    return a


# ELLIOTT MODEL (with Pcv)

# Excitonic part
# --------------

def alpha_x(E, p, fb=0):
    [b0, Ex, Eg, sigma_exc, mu_lognorm, sigma_cont] = p

    alpha = np.zeros(len(E))
    for n in np.arange(1,11):
        mu = Eg-Ex/n**2

        # Select broadening function
        if fb == 0: # Gaussian broadening
            bf = gauss(E, sigma_exc, mu, norm=1)
        elif fb == 1: # lognorm broadening
            bf = fbroad(E-mu, sigma_exc, mu_lognorm, sigma_cont)/np.trapz(fbroad(E-mu, sigma_exc, mu_lognorm, sigma_cont), E-mu)
        elif fb == 2: # cosh broadening
            bf = fb_cosh(E, sigma_exc, mu)

        alpha += (4 * np.pi * Ex**(3/2) / n**3) * bf # the Delta function is the identity in a convolution
    return alpha
 

# Continuum part
# # --------------

def alpha_c(E, p):
    [b0, Ex, Eg, sigma_exc, mu_lognorm, sigma_cont] = p
    
    x = np.where(E > Eg, np.sqrt(Ex / (E - Eg), where=E>Eg), 0)
    xi = 2 * np.pi * x / (1 - np.exp(-2 * np.pi * x, where=x!=0)) # Sommerfeld factor
    free = np.where(E > Eg, np.sqrt(E-Eg, where=E>Eg), 0) # free continuum without exciton
    return np.where(E > Eg, xi * free, 0)

def alpha_c_conv(E, Erange, p, fb=0): # Broadening of the continuum part
    [b0, Ex, Eg, sigma_exc, mu_lognorm, sigma_cont] = p
    
    # build kernel
    dx = E[1] - E[0] # works with evenly distributed abscissa
    kernel = np.arange(-8*sigma_exc, 8*sigma_exc, dx)
    E_pad = np.pad(E, (E.size, E.size), 'edge')
    
    # Select broadening function
    if fb == 0: # Gaussian broadening
        res = np.convolve(alpha_c(E_pad, p), gauss(kernel, sigma_cont, 0, 1), mode='same')/np.sum(gauss(kernel, sigma_cont, 0, norm=1))
    elif fb == 1: # lognorm broadening
        res = signal.convolve(alpha_c(E_pad, p), fbroad(kernel, sigma_cont, mu_lognorm, sigma_cont), mode='same')/np.sum(fbroad(kernel, sigma_cont, mu_lognorm, sigma_cont))
    elif fb == 2: # cosh broadening
        res = signal.convolve(alpha_c(E_pad, p), fb_cosh(kernel, sigma_cont, 0), mode='same')/np.sum(fb_cosh(kernel, sigma_cont, 0))
    
    res = res[np.max(np.where(E_pad <= Erange[0])): np.min(np.where(E_pad >= Erange[1]))+1]
    # division by the sum is required to keep the scale
    return res

def alpha_sum(E, Erange, b0, Ex, Eg, sigma_exc, mu_lognorm, sigma_cont, fb=0):
    p = [b0, Ex, Eg, sigma_exc, mu_lognorm, sigma_cont]

    Esub = E[np.max(np.where(E <= Erange[0])): np.min(np.where(E >= Erange[1]))+1]

    return b0 * Ex**0.5 / Esub * (alpha_x(Esub, p, fb) + alpha_c_conv(Esub, Erange, p, fb))

def dofit(eV, abscoef, guess, eVrange, fb=0):
    # need to have evenly spaced data for the convolution
    eV_interp = np.linspace(eV[0], eV[-1], 600)
    data_interp = np.interp(eV_interp, eV, absCoef)

    # Fit range
    eVsub = eV_interp[np.max(np.where(eV_interp < eVrange[0])): np.min(np.where(eV_interp > eVrange[1]))+1]
    data = data_interp[np.max(np.where(eV_interp < eVrange[0])): np.min(np.where(eV_interp > eVrange[1]))+1]
    print(fb)
    # Using LMfit
    myMod = Model(alpha_sum, independent_vars=['E', 'Erange'])
    print(myMod.param_names)
    myMod.set_param_hint('b0',value=guess[0])
    myMod.set_param_hint('Ex', value=guess[1], min=0., vary=True)
    myMod.set_param_hint('Eg', value=guess[2], min=0., vary=True)
    myMod.set_param_hint('sigma_exc', value=guess[3], vary=False)#, min=0)
    myMod.set_param_hint('sigma_cont', value=guess[4], min=0.0001,vary=True)
    myMod.set_param_hint('mu_lognorm', value=guess[5], vary=True)
    myMod.set_param_hint('fb', value=int(fb), vary=False)
    
    result = myMod.fit(data, E=eV_interp, Erange=eVrange)
    print(result.fit_report())

    # Plotting results
    # ----------------
    best_params = result.best_values
    bp = [best_params['b0'], best_params['Ex'], best_params['Eg'], best_params['sigma_exc'], best_params['mu_lognorm'], best_params['sigma_cont']]

    fit_exc = bp[0] * bp[1]**0.5 / eV_interp * alpha_x(eV_interp, bp, fb)
    fit_cont = bp[0]  * bp[1]**0.5 / eV_interp * alpha_c_conv(eV_interp, [eV_interp[0], eV_interp[-1]], bp, fb)
    fit_y = bp[0] * bp[1]**0.5 / eV_interp * ( alpha_x(eV_interp, bp, fb) + alpha_c_conv(eV_interp, [eV_interp[0], eV_interp[-1]], bp, fb))

    plt.plot(eV, absCoef, label='data', ls=':')
    plt.plot(eV_interp, fit_exc, label='exc')
    plt.plot(eV_interp, fit_cont, label='cont')
    # plt.plot(eVsub, result.best_fit, label='fit')
    plt.plot(eV_interp, fit_y, ls='-', label='fit full')

    plt.xlim([2.0, 3.5])
    plt.ylim([0, 40])
    plt.xlabel('eV'); plt.ylabel('Abs. coef [1e4 * cm-1]')
    plt.ticklabel_format(style='sci', axis='y', scilimits=(0,3))

    plt.legend()
    plt.show()

    np.savetxt('fits.dat', np.c_[eV_interp, data_interp, fit_y*scale, fit_exc*scale, fit_cont*scale])
    

" Data processing "
# load and convert data
loaded = np.loadtxt('PEA.txt', skiprows=1, delimiter='\t', unpack=False)

absCoef = loaded[:, 1] # for scaling but needs to be appropriately calculateds

#absCoef -= 2100 # Remove baseline

scale = 1e6
absCoef /= scale # rescale
eV = loaded[:, 0]


# Guess parameters
b0 = 140e5 #
Ex = 184.3 # meV
Eg = 2.5997 # eV
sigma_exc = 30. # meV
sigma_cont = 60. # meV
mu_lognorm = -0.2 # eV

fit_range = [2.15, 2.8]

guess = [b0 / scale, Ex*1e-3, Eg, sigma_exc*1e-3, sigma_cont*1e-3, mu_lognorm]
dofit(eV, absCoef, guess, fit_range, fb=1)

"""
